import glob
import inspect
import os
import shutil
import subprocess
from typing import List, Callable

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # silence TF C++ backend logging (INFO, WARNING and ERROR)


def _compress(src_path: str, dst_path: str):
    """
    Compress the contents of src_path into a zip archive written to dst_path (a '.zip' suffix is appended)

    :param src_path: (str) Source directory to compress
    :param dst_path: (str) Destination path of the archive, without the '.zip' extension
    """
    print('[*] Compressing...')
    shutil.make_archive(dst_path, 'zip', src_path)
    print('[*] Compressed the contents in: {}.zip'.format(dst_path))
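# Illustrative usage (hypothetical paths): _compress('/results/onnx', '/results/onnx_model')
# writes '/results/onnx_model.zip' containing the contents of '/results/onnx'.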


def _print_input(func: Callable):
    """
    Decorator printing function name and args
    :param func: (Callable) Decorated function
    :return: (Callable) Wrapped function
    """

    def wrapper(*args, **kwargs):
        """
        Print the name and arguments of a function

        :param args: Positional arguments
        :param kwargs: Keyword arguments
        :return: Original function call
        """
        # Silence TF Python-side logging before running the wrapped export function
        tf.logging.set_verbosity(tf.logging.ERROR)
        func_args = inspect.signature(func).bind(*args, **kwargs).arguments
        func_args_str = ''.join('\t{} = {!r}\n'.format(*item) for item in func_args.items())

        print('[*] Running \'{}\' with arguments:'.format(func.__qualname__))
        print(func_args_str[:-1])

        return func(*args, **kwargs)

    return wrapper
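# Illustrative output of a decorated call (argument values are hypothetical):
#   [*] Running 'to_onnx' with arguments:
#       input_dir = '/results/savedmodel'
#       output_dir = '/results/onnx'
#       compress = False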


def _parse_placeholder_types(values: str):
    """
    Extracts placeholder types from a comma-separated list.

    :param values: (str) Comma-separated placeholder type enum values
    :return: (List[int] or int) Placeholder types; a single int if only one value is given
    """
    values = [int(value) for value in values.split(",")]
    return values if len(values) > 1 else values[0]


def _optimize_checkpoint_for_inference(graph_path: str,
                                       input_names: List[str],
                                       output_names: List[str]):
    """
    Removes Horovod and training-related nodes from the graph

    :param graph_path: (str) Path to the graph.pbtxt file
    :param input_names: (List[str]) Input node names
    :param output_names: (List[str]) Output node names
    """

    print('[*] Optimizing graph for inference ...')

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(graph_path, "rb") as f:
        data = f.read()
        text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        input_names,
        output_names,
        _parse_placeholder_types(str(dtypes.float32.as_datatype_enum)),
        False)

    print('[*] Saving original graph in: {}'.format(graph_path + '.old'))
    shutil.move(graph_path, graph_path + '.old')

    print('[*] Writing down optimized graph ...')
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(graph_path),
                         os.path.basename(graph_path))


@_print_input
def to_savedmodel(input_shape: List[int],
                  model_fn: Callable,
                  checkpoint_dir: str,
                  output_dir: str,
                  input_names: List[str],
                  output_names: List[str],
                  use_amp: bool,
                  use_xla: bool,
                  compress: bool):
    """
    Export a checkpoint to a TensorFlow SavedModel

    :param input_shape: (List[int]) Input shape of the model in format [batch, height, width, channels]
    :param model_fn: (Callable) Estimator's model_fn
    :param checkpoint_dir: (str) Directory where checkpoints are stored
    :param output_dir: (str) Output directory for storage of the generated SavedModel
    :param input_names: (List[str]) Input node names
    :param output_names: (List[str]) Output node names
    :param use_amp: (bool) Enable TF-AMP
    :param use_xla: (bool) Enable XLA
    :param compress: (bool) Compress output
    """
    assert os.path.exists(checkpoint_dir), 'Path not found: {}'.format(checkpoint_dir)
    assert input_shape is not None, 'Input shape must be provided'

    _optimize_checkpoint_for_inference(os.path.join(checkpoint_dir, 'graph.pbtxt'), input_names, output_names)

    try:
        # Each checkpoint has a matching '<prefix>.index' file; recover the prefix from the first one found
        ckpt_path = os.path.splitext(glob.glob(os.path.join(checkpoint_dir, '*.index'))[0])[0]
    except IndexError:
        raise ValueError('Could not find a checkpoint in directory: {}'.format(checkpoint_dir))

    config_proto = tf.compat.v1.ConfigProto()

    config_proto.allow_soft_placement = True
    config_proto.log_device_placement = False
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.force_gpu_compatible = True

    if use_amp:
        # Enable TensorFlow's automatic mixed precision graph rewrite (TF-AMP)
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
    if use_xla:
        # Enable XLA JIT compilation for the exported graph
        config_proto.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1

    # Summaries, checkpointing and step-count logging are effectively disabled: this RunConfig
    # only drives the export below
    run_config = tf.estimator.RunConfig(
        model_dir=None,
        tf_random_seed=None,
        save_summary_steps=1e9,  # disabled
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        session_config=config_proto,
        keep_checkpoint_max=None,
        keep_checkpoint_every_n_hours=1e9,  # disabled
        log_step_count_steps=1e9,
        train_distribute=None,
        device_fn=None,
        protocol=None,
        eval_distribute=None,
        experimental_distribute=None
    )

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=ckpt_path,
        config=run_config,
        params={'dtype': tf.float16 if use_amp else tf.float32}
    )

    print('[*] Exporting the model ...')

    input_type = tf.float16 if use_amp else tf.float32

    def get_serving_input_receiver_fn():

        def serving_input_receiver_fn():
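            # Expose a single dense placeholder as the serving input; TensorServingInputReceiver
            # (unlike ServingInputReceiver) accepts a raw tensor as `features` rather than a dict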
            features = tf.placeholder(dtype=input_type, shape=input_shape, name='input_tensor')

            return tf.estimator.export.TensorServingInputReceiver(features=features, receiver_tensors=features)

        return serving_input_receiver_fn

    export_path = estimator.export_saved_model(
        export_dir_base=output_dir,
        serving_input_receiver_fn=get_serving_input_receiver_fn(),
        checkpoint_path=ckpt_path
    )
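    # export_saved_model returns the path of the export directory as bytes, hence the .decode() calls below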

    print('[*] Done! path: `%s`' % export_path.decode())

    if compress:
        _compress(export_path.decode(), os.path.join(output_dir, 'saved_model'))


@_print_input
def to_tf_trt(savedmodel_dir: str,
              output_dir: str,
              precision: str,
              feed_dict_fn: Callable,
              num_runs: int,
              output_tensor_names: List[str],
              compress: bool):
    """
    Export a TensorFlow SavedModel to TF-TRT

    :param savedmodel_dir: (str) Input directory containing a TensorFlow SavedModel
    :param output_dir: (str) Output directory for storage of the generated TF-TRT exported model
    :param precision: (str) Desired precision of the network (FP32, FP16 or INT8)
    :param feed_dict_fn: (Callable) Input tensors for INT8 calibration. Model specific.
    :param num_runs: (int) Number of calibration runs.
    :param output_tensor_names: (List[str]) Names of the output tensors for graph conversion. Model specific.
    :param compress: (bool) Compress output
    """
    if savedmodel_dir is None or not os.path.exists(savedmodel_dir):
        raise FileNotFoundError('savedmodel_dir not found: {}'.format(savedmodel_dir))

    if os.path.exists(output_dir):
        print('[*] Output dir \'{}\' already exists. Cleaning up ...'.format(output_dir))
        shutil.rmtree(output_dir)

    print('[*] Converting model...')

    # TF-TRT: replace supported subgraphs of the SavedModel with TensorRT engine ops
    converter = trt.TrtGraphConverter(input_saved_model_dir=savedmodel_dir,
                                      precision_mode=precision)
    converter.convert()

    if precision == 'INT8':
        print('[*] Running INT8 calibration ...')

        # INT8 needs a calibration pass: feed `num_runs` batches from `feed_dict_fn` through the
        # graph to collect the activation ranges used for quantization
        converter.calibrate(fetch_names=output_tensor_names, num_runs=num_runs, feed_dict_fn=feed_dict_fn)

    converter.save(output_dir)

    print('[*] Done! TF-TRT saved_model stored in: `%s`' % output_dir)

    if compress:
        _compress(output_dir, 'tftrt_saved_model')


@_print_input
def to_onnx(input_dir: str, output_dir: str, compress: bool):
    """
    Convert a TensorFlow SavedModel to ONNX with tf2onnx

    :param input_dir: (str) Input directory containing a TensorFlow SavedModel
    :param output_dir: (str) Output directory in which to store the ONNX version of the model
    :param compress: (bool) Compress output
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    file_name = os.path.join(output_dir, 'model.onnx')
    print('[*] Converting model...')

    # Shell out to the tf2onnx CLI; its output is discarded and only the exit code is checked
    ret = subprocess.call(['python', '-m', 'tf2onnx.convert',
                           '--saved-model', input_dir,
                           '--output', file_name],
                          stdout=open(os.devnull, 'w'),
                          stderr=subprocess.STDOUT)
    if ret != 0:
        raise RuntimeError('tf2onnx.convert has failed with exit code: {}'.format(ret))

    print('[*] Done! ONNX file stored in: %s' % file_name)

    if compress:
        _compress(output_dir, 'onnx_model')

